import pandas as pd
import matplotlib.pyplot as plt

print("=== Generating Figure 5 (FULL DATASET 2016–2023) ===")

# Lookup table: dataset sector code -> English sector name shown on the bars.
# Built from ordered (code, name) pairs so insertion order matches the listing.
_sector_pairs = (
    ('NLT', 'Agriculture, forestry and fishing'),
    ('KKK', 'Mining and quarrying'),
    ('CBC', 'Manufacturing'),
    ('SXP', 'Electricity, gas, steam and air conditioning supply'),
    ('NRT', 'Water supply; sewerage, waste management and remediation activities'),
    ('XDD', 'Construction'),
    ('BBL', 'Wholesale and retail trade; repair of motor vehicles and motorcycles'),
    ('VTK', 'Transportation and storage'),
    ('LTA', 'Accommodation and food service activities'),
    ('TTT', 'Information and communication'),
    ('TNB', 'Financial and insurance activities'),
    ('BDS', 'Real estate activities'),
    ('KHC', 'Professional, scientific and technical activities'),
    ('HCD', 'Administrative and support service activities'),
    ('GDD', 'Education'),
    ('YTT', 'Human health and social work activities'),
    ('VCG', 'Arts, entertainment and recreation'),
    ('DVK', 'Other service activities'),
)
sector_map = dict(_sector_pairs)

# Load the full dataset
df = pd.read_csv("dataset.csv")
print(f"Dataset loaded: {len(df):,} observations")

# Determine the covered period from the data itself rather than hard-coding it.
years = df['year'].dropna()
if years.empty:
    # Without this guard, int(NaN) below fails with the opaque
    # "cannot convert float NaN to integer".
    raise ValueError("dataset.csv has no non-missing 'year' values; "
                     "cannot determine the period for Figure 5")
min_year = int(years.min())
max_year = int(years.max())
# A single-year dataset reads "(2020)" rather than "(2020–2020)".
if min_year == max_year:
    period_str = f"({min_year})"
else:
    period_str = f"({min_year}–{max_year})"
print(f"Using FULL period: {period_str}")

# Map sector codes to English names & compute per-sector average wages.
df['sector_en'] = df['sector_id'].map(sector_map)

# Series.map yields NaN for codes absent from sector_map, and groupby() drops
# NaN keys by default (dropna=True), so such sectors would silently vanish
# from the figure. Report them and fall back to the raw code as the label.
unmapped_mask = df['sector_en'].isna() & df['sector_id'].notna()
if unmapped_mask.any():
    unmapped_codes = sorted(df.loc[unmapped_mask, 'sector_id'].astype(str).unique())
    print(f"WARNING: no English name for sector code(s) {unmapped_codes}; "
          "using the raw code as the label")
    df['sector_en'] = df['sector_en'].fillna(df['sector_id'])

mean_wages = df.groupby(['sector_id', 'sector_en'])['wage_sr_t'].mean()

# A sector with no non-missing wages has a NaN mean, which would make the
# int cast below raise; it cannot be plotted anyway, so drop it with a notice.
no_wage_data = mean_wages.isna()
if no_wage_data.any():
    dropped = mean_wages[no_wage_data].index.get_level_values('sector_id').tolist()
    print(f"WARNING: dropping sector(s) with no wage data: {dropped}")
    mean_wages = mean_wages[~no_wage_data]

# Whole-unit averages, ascending (lowest-paying sector first).
sector_wages = (mean_wages
                .round(0)
                .astype(int)
                .reset_index()
                .sort_values('wage_sr_t', ascending=True))

print(f"Computed average wages for {len(sector_wages)} sectors")

# Plot: one horizontal bar per sector, lowest-paying sectors highlighted red
# and highest-paying sectors highlighted blue.
fig, ax = plt.subplots(figsize=(10, 8))
n = len(sector_wages)
# Number of sectors highlighted at each end. Capped at n // 2 so the red
# (lowest) and blue (highest) groups never overlap when there are fewer than
# 10 sectors; for n >= 10 this is exactly 5.
k = min(5, n // 2)
# sector_wages is sorted ascending, so the first k rows are the lowest wages
# and the last k rows the highest.
colors = ['red' if i < k else 'blue' if i >= n - k else 'gray' for i in range(n)]
bars = ax.barh(sector_wages['sector_en'], sector_wages['wage_sr_t'],
               color=colors, edgecolor='black')

ax.set_xlabel('Average Monthly Wage (thousand VND)', fontsize=12)
ax.set_title(f'Figure 5: Average Monthly Wages by Sector {period_str}',
             fontsize=14, pad=20)

for bar in bars:
    width = bar.get_width()
    # Offset the value label in points rather than data units so the gap to
    # the bar end is the same whatever the wage scale.
    ax.annotate(f'{int(width):,}',
                xy=(width, bar.get_y() + bar.get_height() / 2),
                xytext=(4, 0), textcoords='offset points',
                va='center', ha='left', fontsize=10)

# Extra right-hand headroom so the labels of the longest bars stay inside the
# axes; bars are pinned at x=0, so the left edge is unaffected.
ax.margins(x=0.12)

# barh draws the first row at the bottom; inverting puts it (the lowest wage)
# at the top.
ax.invert_yaxis()
ax.grid(axis='x', linestyle='--', alpha=0.7)

fig.tight_layout()

# === SAVE BOTH FORMATS ===
# Each entry: output path plus any savefig options beyond the shared
# 600 dpi / tight bounding box settings.
figure_outputs = (
    # High-quality TIFF, LZW-compressed through Pillow
    ("figure5_sectoral_wages.tif",
     dict(format='tiff', pil_kwargs={'compression': 'tiff_lzw'})),
    # High-quality PNG (format inferred from the file extension)
    ("figure5_sectoral_wages.png", {}),
)
for out_path, extra_opts in figure_outputs:
    plt.savefig(out_path, dpi=600, bbox_inches='tight', **extra_opts)

# Release the current figure once both files are written.
plt.close()

print("\n".join((
    "\n✅ Figure 5 saved successfully in BOTH formats!",
    "   • figure5_sectoral_wages.tif (600 dpi TIFF LZW)",
    "   • figure5_sectoral_wages.png (600 dpi PNG)",
    f"   Period used: {period_str}",
    "   (Red = 5 lowest-paying sectors | Blue = 5 highest-paying sectors)",
)))